from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import openai
from collections import defaultdict
from datetime import datetime

class Gsm8kReasoner(BaseReasoner):
    """Prompt-engineering reasoner for the GSM8K math word-problem dataset.

    ``execute_workflow`` runs a fixed pipeline against ``self.llm``:
    constraint extraction -> solution-method exploration -> selection of the
    best-scoring method -> execution of that method -> answer extraction.
    ``save_results`` / ``verify_answer`` compare the extracted answer with the
    reference answer, which GSM8K stores after a ``####`` marker.
    """

    # A backslash, optionally followed by an escape that is already valid JSON
    # and must be kept intact: an escaped backslash, an escaped quote, "\/",
    # or a "\uXXXX" code point. Any other backslash (typically LaTeX such as
    # "\frac" or "\sqrt" emitted by the LLM) is doubled so json.loads keeps it
    # literally instead of failing or turning "\f" into a form feed.
    _JSON_BACKSLASH_RE = re.compile(r'\\(\\|"|/|u[0-9a-fA-F]{4})?')
    # A signed integer or decimal literal.
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
    # A comma used as a thousands separator, e.g. the one in "1,000".
    _THOUSANDS_SEP_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")

    def __init__(self):
        super().__init__("gsm8k")
        self.config.dataset_path = "datasets/gsm8k.json"

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load problems ``start_idx`` (inclusive) to ``end_idx`` (exclusive).

        Reads the JSON file at ``self.config.dataset_path`` and slices its
        top-level value. Loading is best-effort: any failure is printed and an
        empty list is returned so a batch run can carry on.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, question: str) -> Dict[str, Any]:
        """Run the full prompt-engineering pipeline for one question.

        Returns ``{"status": "success", "final_answer", "nodes", "logs",
        "token_usage"}``. ``final_answer`` is "No solution found" when no
        answer could be extracted. On any exception the traceback is printed
        and ``{"status": "error", "message", "logs"}`` is returned instead.
        """
        try:
            # Step 0: extract problem constraints. The helper always returns a
            # dict, but guard anyway in case a subclass overrides it.
            constraints = await self._extract_constraints(question)
            if not isinstance(constraints, dict):
                constraints = {
                    "explicit": [],
                    "implicit": [],
                    "notes": "Invalid constraints format"
                }
            self._log_step("step0", "system", {"constraints": constraints})

            # Step 1: root node carrying the question and its constraints.
            root = self._create_node(
                path=[],
                method={"description": "Original problem"},
                steps=[],
                constraints={
                    "explicit": constraints.get("explicit", []),
                    "implicit": constraints.get("implicit", [])
                },
                question=question
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step 2: explore candidate methods (sorted best-first, [] on failure).
            methods = await self._explore_solutions(question)
            self._log_step("step2", root.node_id, {"methods": methods})
            if not methods:
                # Fail with an explicit message rather than max()'s
                # "arg is an empty sequence".
                raise ValueError("No valid solution methods were generated")

            # Step 3: pick the best method within the beam and create its node.
            candidates = methods[:self.config.beam_width] or methods
            best_method = max(
                candidates,
                key=lambda m: self._coerce_score(m.get("score", 0))
            )
            best_node = self._create_node(
                path=[root.node_id],
                method=best_method,
                steps=best_method.get("steps", []),
                score=best_method.get("score", 0),
                constraints=root.constraints,  # Inherit root constraints
                parent_id=root.node_id,
                question=question
            )
            root.children.append(best_node.node_id)
            self._log_step("step3", best_node.node_id, {"method": best_method})

            # Step 4: solve the best node directly (no classification step).
            solution = await self._solve_node(best_node.node_id)
            final_answer = solution["answer"] if solution else "No solution found"
            self._log_step("step4", best_node.node_id, {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
                "token_usage": self.llm.token_counts
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }

    async def _extract_constraints(self, question: str) -> Dict[str, Any]:
        """Ask the LLM for explicit and implicit constraints of *question*.

        Retries up to ``self.config.max_retries`` times when the response is
        not a JSON object or lists no constraints at all. After the retries are
        exhausted, placeholder constraints are returned, so the result is
        always a dict with ``explicit``, ``implicit`` and ``notes`` keys.
        """
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Analyze this math problem and extract ALL constraints:
    
    problem:{question}
    
    Notice:
    1. Identify explicit constraints (directly stated in the problem)
    2. Derive implicit constraints (e.g., denominators ≠ 0, square roots ≥ 0, log arguments > 0)
    3. Determine domain restrictions based on mathematical principles
    4. Identify range limitations from problem context
    5. Extract physical meaning constraints (e.g., length > 0, probability ∈ [0,1])
    
    Output JSON format:
    {{
        "explicit": ["constraint1", "constraint2"],
        "implicit": ["constraint1", "constraint2"],
        "notes": "Additional analysis notes"
    }}"""

        for attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                # Lenient parse: tolerates code fences and stray LaTeX
                # backslashes that would otherwise be invalid JSON escapes.
                data = self._loads_lenient(response or "")
            except (json.JSONDecodeError, AttributeError, TypeError) as e:
                print(f"Parse error (attempt {attempt+1}): {str(e)}")
                continue

            if not isinstance(data, dict):
                print(f"Invalid response type (attempt {attempt+1}): {type(data)}")
                continue

            constraints = {
                "explicit": data.get("explicit", []),
                "implicit": data.get("implicit", []),
                "notes": data.get("notes", "")
            }

            if not (constraints["explicit"] or constraints["implicit"]):
                print(f"Empty constraints (attempt {attempt+1})")
                continue

            return constraints

        print("All retries failed, returning default constraints")
        return {
            "explicit": ["Default explicit constraint"],
            "implicit": ["Default implicit constraint"],
            "notes": "Fallback constraints"
        }

    async def _explore_solutions(self, question: str) -> List[Dict[str, Any]]:
        """Step 2: ask the LLM for exactly three distinct solution methods.

        Each returned method is a dict with ``description``, ``steps`` (list),
        ``score`` (coerced to a number) and ``score_reason``. The list is
        sorted by descending score. An empty list is returned after
        ``self.config.max_retries`` malformed responses.
        """
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Generate 3 distinct solution methods for:

{question}

Notice:
1. Employ different theoretical frameworks (algebraic, geometric, analytical, etc.)
2. Approach from fundamentally different perspectives
3. Vary implementation techniques significantly
4. Consider both conventional and innovative methods
5. Steps can be retained as ideas only, without exact calculations
6. Pay attention to the mathematical expressions in the questions and understand them correctly
7. examine carefully the subject matter

For each method, provide:
- Clear description of the mathematical approach
- Step-by-step implementation plan
- Effectiveness score (0-100) based on:
  * Mathematical rigor
  * Computational feasibility
  * Logical completeness
  * Efficiency

Output JSON format:
{{
    "methods": [
        {{
            "description": "Method description",
            "steps": ["step1", "step2"],
            "score": 0-100,
            "score_reason": "Scoring justification"
        }}
    ]
}}"""

        required_keys = {"description", "steps", "score", "score_reason"}
        # Initialised up front so the final-failure message never hits an
        # unbound name.
        raw_response = ""
        for attempt in range(self.config.max_retries):
            try:
                raw_response = await self.llm.generate(prompt, response_format="json_object")
                payload = self._escape_json_backslashes(
                    self._strip_code_fence(raw_response or "")
                )
                data = json.loads(payload)

                if not isinstance(data, dict) or "methods" not in data:
                    raise ValueError("Invalid structure: missing 'methods' key")

                methods = data["methods"]
                if not isinstance(methods, list):
                    raise ValueError("'methods' must be a list")
                if len(methods) != 3:
                    raise ValueError(f"Expected 3 methods, got {len(methods)}")

                for method in methods:
                    if not isinstance(method, dict) or not required_keys.issubset(method):
                        raise ValueError("Missing required keys in method")
                    if not isinstance(method["steps"], list):
                        raise ValueError("Steps must be a list")
                    # LLMs sometimes return "85" or "85/100"; normalise so
                    # sorting and max() never compare str with int.
                    method["score"] = self._coerce_score(method["score"])

                return sorted(methods, key=lambda m: -m["score"])

            except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt == self.config.max_retries - 1:
                    print(f"Final failed response: {raw_response}")
                    return []
                continue

        return []

    async def _solve_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Have the LLM execute the steps of node *node_id*.

        On success, stores the extracted answer on the node, marks it
        "solved" and returns ``{"node_id", "response", "answer"}``. Returns
        None when no answer can be extracted from the response.
        """
        node = self.nodes[node_id]
        root_node = self.nodes[node.path[0]] if node.path else node
        # The root's method description is only the placeholder "Original
        # problem", so the real question text must come from the nodes, which
        # are created with question=... . getattr guards node types that do
        # not expose it (NOTE(review): confirm ReasoningNode stores .question).
        original_question = (
            getattr(node, "question", None)
            or getattr(root_node, "question", None)
            or root_node.method.get("description", "Original problem")
        )

        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    You are a meticulous mathematical problem solver executing this solution:
    
    Original Problem: {original_question}
    Steps: {node.steps}
    Constraints: {node.constraints}
    
    As an executor, you must:
    1. Follow the provided steps precisely
    2. Explicitly verify all constraints at each step
    3. Show complete mathematical justification
    4. Use proper mathematical notation
    5. Clearly mark the final answer with \\boxed{{}}
    6. Include standalone line: "Final Answer: answer"
    7. Ensure your answer directly responds to the question asked
    8. The final answer should be one exact number
    9. Pay attention to the mathematical expressions in the questions and understand them correctly
    10. examine carefully the subject matter
    
    Additional requirements:
    - Show all intermediate calculations
    - State any assumptions made
    - Verify solution satisfies all constraints
    - Cross-validate critical steps
    - If the question asks for GCD, provide only the GCD as final answer
    - If you calculate intermediate values (like A and B), clearly distinguish them from the final answer"""

        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)

        if answer:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer
            }
        return None

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the final answer from an LLM response.

        The answer is taken from the first of these that applies:

        1. The content of the last non-empty ``\\boxed{...}``. Nested braces
           such as ``\\boxed{\\frac{1}{2}}`` are supported.
        2. The remainder of the first "Final Answer:" line (case-insensitive).

        Returns None when *text* is empty or neither form is present.
        """
        if not text:
            return None

        boxed = None
        for match in re.finditer(r"\\boxed\{", text):
            # Walk forward to the brace that closes this \boxed{ .
            depth = 1
            pos = match.end()
            while pos < len(text) and depth:
                if text[pos] == "{":
                    depth += 1
                elif text[pos] == "}":
                    depth -= 1
                pos += 1
            if depth == 0:
                content = text[match.end():pos - 1].strip()
                if content:  # skip echoes of the empty "\boxed{}" instruction
                    boxed = content
        if boxed is not None:
            return boxed

        final_answer_match = re.search(
            r'Final\s+Answer\s*:\s*([^\n]+)',
            text,
            re.IGNORECASE
        )
        if final_answer_match:
            return final_answer_match.group(1).strip()

        return None

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Build the record for one problem: question, result and verification.

        The ``nodes`` entry is removed from *result* in place. The reference
        answer is taken from ``problem["solution"]`` or ``problem["answer"]``.
        ``verification["is_correct"]`` uses the same matching rules as
        ``verify_answer``.
        """
        result.pop("nodes", None)

        verification = {
            "is_correct": False,
            "correct_answer": None,
            "given_answer": result.get("final_answer")
        }

        correct_answer = self._reference_answer(problem)
        verification["correct_answer"] = correct_answer
        if correct_answer is not None and "final_answer" in result:
            verification["is_correct"] = self._answers_match(
                result["final_answer"], correct_answer
            )

        return {
            "question": problem.get("question"),
            "result": result,
            "verification": verification
        }

    def _extract_correct_answer(self, solution: str) -> Optional[str]:
        """Return the text after the last ``####`` marker in *solution*, or None.

        GSM8K reference solutions end with a line such as ``#### 18``.
        """
        hash_pattern = r'####\s*([^\n]+)'
        hash_matches = re.findall(hash_pattern, solution)
        return hash_matches[-1].strip() if hash_matches else None

    async def verify_answer(self, problem: Dict[str, Any], final_answer: str) -> bool:
        """Return True if *final_answer* matches the problem's reference answer.

        The reference is taken from ``solution`` or ``answer`` (see
        ``_reference_answer``). Returns False when there is none.
        """
        correct_answer = self._reference_answer(problem)
        if not correct_answer:
            return False

        return self._answers_match(final_answer, correct_answer)

    def _reference_answer(self, problem: Dict[str, Any]) -> Optional[str]:
        """Return the reference answer of *problem*, or None if none is found.

        ``solution`` is tried before ``answer``. A string field is searched for
        a ``####`` marker. A numeric field, or an ``answer`` string that is
        itself just a number, is used as-is.
        """
        for key in ("solution", "answer"):
            value = problem.get(key)
            if value is None:
                continue
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return str(value)
            text = str(value)
            extracted = self._extract_correct_answer(text)
            if extracted is not None:
                return extracted
            if key == "answer" and self._NUMBER_RE.fullmatch(
                self._THOUSANDS_SEP_RE.sub("", text.strip())
            ):
                return text.strip()
        return None

    @classmethod
    def _answers_match(cls, given: Any, expected: Any) -> bool:
        """Compare a model answer with a reference answer.

        The answers match if any of these holds:

        - the stripped strings are equal;
        - the reference is a single letter (multiple-choice style) and the
          given answer ends with it;
        - both parse to numbers (see ``_parse_number``) that are equal within
          1e-10.
        """
        given_text = str(given).strip()
        expected_text = str(expected).strip()
        if given_text == expected_text:
            return True
        # Restricted to letters: for digits, "18".endswith("8") would be a
        # false positive.
        if len(expected_text) == 1 and expected_text.isalpha() and given_text.endswith(expected_text):
            return True
        given_num = cls._parse_number(given_text)
        expected_num = cls._parse_number(expected_text)
        return (
            given_num is not None
            and expected_num is not None
            and abs(given_num - expected_num) < 1e-10
        )

    @classmethod
    def _parse_number(cls, text: Any) -> Optional[float]:
        """Parse an answer string such as "$1,000", "18." or "25%" as a float.

        Thousands-separator commas, "$" / "\\$" and "%" / "\\%" are removed,
        along with a trailing period. If the remainder is still not a float but
        contains exactly one number (e.g. "18 dollars"), that number is used.
        Returns None otherwise.
        """
        if text is None:
            return None
        cleaned = cls._THOUSANDS_SEP_RE.sub("", str(text).strip())
        for token in ("\\$", "$", "\\%", "%"):
            cleaned = cleaned.replace(token, "")
        cleaned = cleaned.strip().rstrip(".").strip()
        try:
            return float(cleaned)
        except ValueError:
            numbers = cls._NUMBER_RE.findall(cleaned)
            return float(numbers[0]) if len(numbers) == 1 else None

    @staticmethod
    def _coerce_score(value: Any) -> Union[int, float]:
        """Return *value* if it is already numeric, else the first number in it, else 0."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        match = re.search(r"-?\d+(?:\.\d+)?", str(value))
        return float(match.group()) if match else 0

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (optionally tagged "json").

        A missing closing fence is tolerated rather than chopping real content.
        """
        text = text.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.startswith("json"):
                text = text[4:]
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    @classmethod
    def _escape_json_backslashes(cls, text: str) -> str:
        """Double every backslash that does not start \\\\, \\", \\/ or \\uXXXX."""
        return cls._JSON_BACKSLASH_RE.sub(
            lambda m: m.group(0) if m.group(1) else "\\\\", text
        )

    @classmethod
    def _loads_lenient(cls, text: str) -> Any:
        """Parse *text* as JSON, retrying with LaTeX backslashes escaped.

        Raises json.JSONDecodeError if both attempts fail.
        """
        payload = cls._strip_code_fence(text)
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return json.loads(cls._escape_json_backslashes(payload))